{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5d904dee",
   "metadata": {},
   "source": [
    "# Use MultKAN\n",
    "\n",
    "In this example, we will show how to use MultKAN"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "94056ef6",
   "metadata": {},
   "source": [
    "Initialize the model and create the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0a59179d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from kan.MultKAN import MultKAN\n",
    "from kan.utils import create_dataset\n",
    "import torch\n",
    "\n",
    "torch.set_default_dtype(torch.float64)\n",
    "torch.use_deterministic_algorithms(True)\n",
    "\n",
    "seed = 0\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "# initialize KAN with G=5. Setting the hyperparameters below is the best setup if you care about speed and never use the symbolic front\n",
    "# symbolic_enabled = False will ignore the symbolic part to avoid slowdown due to the for loops in the symbolic part\n",
    "# save_plot_data = False will skip saving intermediate activations\n",
    "# auto_save = False will skip saving models\n",
    "model = MultKAN(width=[2,1,1], grid=5, k=3, symbolic_enabled=False, save_plot_data=False, auto_save=False)\n",
    "\n",
    "# create dataset\n",
    "f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)\n",
    "dataset = create_dataset(f, n_var=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb1f817e",
   "metadata": {},
   "source": [
    "Train KAN (grid=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a87b97b0",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train loss: 5.90e-03 | test loss: 6.09e-03 | reg: 0.00e+00 : 100%|█| 100/100 [00:11<00:00,  8.56it/s\n"
     ]
    }
   ],
   "source": [
    "model.train(dataset, opt=\"LBFGS\", steps=100);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3f1cfc9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save_plot_data = True\n",
    "model(dataset['train_input'])\n",
    "model = model.refine(10)\n",
    "model.save_plot_data = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f6900184",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train loss: 2.90e-04 | test loss: 3.05e-04 | reg: 0.00e+00 : 100%|█| 100/100 [00:12<00:00,  8.05it/s\n"
     ]
    }
   ],
   "source": [
    "model.train(dataset, opt=\"LBFGS\", steps=100);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c3a785ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save_plot_data = True\n",
    "model(dataset['train_input'])\n",
    "model = model.refine(20)\n",
    "model.save_plot_data = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "10d710ec",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train loss: 1.61e-05 | test loss: 1.73e-05 | reg: 0.00e+00 : 100%|█| 100/100 [00:17<00:00,  5.64it/s\n"
     ]
    }
   ],
   "source": [
    "model.train(dataset, opt=\"LBFGS\", steps=100);"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
